In [1]:
import numpy as np

A compact fixed-point neuron model

This code implements a complete fixed-point neuron, including both soma and synapse, using 32 bits of memory per neuron. These 32 bits are allocated as:

  • 6 bits for membrane voltage
  • 2 bits for refractory period
  • 8 bits for bias input current
  • 16 bits for synapse state

The pool of neurons has a fixed soma membrane time constant of 16*dt. The refractory period is a fixed multiple of dt (0, 1, 2, or 3). All synapses share the same time constant tau_syn (which must be of the form dt*(2**N)), and the same synaptic strength weight_syn (which must be a power of 2).

Note that the bias values do not change while the model is running, so only 24 bits of state are actively written to.


In [89]:
class CompactPool:
    """A pool of fixed-point spiking neurons, one uint32 state word per neuron.

    Layout of each entry of ``self.state``:

      * bits  5-0  : membrane voltage (unsigned; a neuron spikes at 0x40)
      * bits  7-6  : refractory counter (steps remaining, 0..3)
      * bits 15-8  : bias current (signed 8-bit, one fractional bit)
      * bits 31-16 : synaptic current (signed 16-bit, 11 fractional bits)

    Parameters
    ----------
    n_neurons : int
        Number of neurons in the pool.
    dt : float
        Simulation time step in seconds.
    tau_syn : float
        Synaptic time constant in seconds; effectively rounded up to
        ``dt * 2**k`` for the smallest integer k >= 0.
    weight_syn : float
        Synaptic weight; effectively rounded up to the next power of two.
    refractory_steps : int
        Refractory period in time steps; must be in 0..3 (2-bit field).
    """

    def __init__(self, n_neurons, dt=0.001, tau_syn=0.008, weight_syn=8.0, refractory_steps=2):
        # must fit in the 2-bit refractory field
        assert 0 <= refractory_steps <= 3
        self.n_neurons = n_neurons
        # NOTE(review): not read by any method here; biases live in self.state
        self.bias = 0
        self.dt = dt

        # smallest shift with 2**weight_shift >= weight_syn
        self.weight_shift = 0
        self.weight_syn = weight_syn
        while (1 << self.weight_shift) < weight_syn:
            self.weight_shift += 1

        self.refractory_steps = refractory_steps

        self.state = np.zeros(n_neurons, dtype=np.uint32)

        # smallest shift with dt * 2**decay_shift >= tau_syn; each step the
        # synaptic current decays by current / 2**decay_shift
        decay_shift = 0
        while (1 << decay_shift) * dt < tau_syn:
            decay_shift += 1
        self.decay_shift = decay_shift

    @staticmethod
    def _to_signed(field, bits):
        """Reinterpret an unsigned ``bits``-wide bit field as two's-complement int32."""
        signed = field.astype(np.int32)
        signed[signed >= (1 << (bits - 1))] -= 1 << bits
        return signed

    def set_bias(self, bias):
        """Set the bias current of every neuron.

        ``bias`` may be a scalar or an array broadcastable to ``(n_neurons,)``.
        Values are stored with one fractional bit (resolution 0.5), truncated
        toward zero, in a signed 8-bit field covering [-64, 63.5]; values
        outside that range wrap around rather than saturate.
        """
        raw = (np.asarray(bias, dtype=float) * 2).astype(np.int32)
        # mask to 8 bits first so the uint32 cast is non-negative and the
        # in-place OR below stays uint32 (same_kind casting)
        raw = (raw & 0xFF).astype(np.uint32)
        self.state &= 0xFFFF00FF
        self.state |= raw << 8

    def get_bias(self):
        """Return the bias current of each neuron as floats (resolution 0.5)."""
        raw = (self.state >> 8) & 0xFF
        return self._to_signed(raw, 8).astype(float) / 2.0

    def get_voltage(self):
        """Return each neuron's membrane voltage as a fraction of threshold (0x40)."""
        voltage = self.state & 0x0000003F
        return voltage.astype(float) / 0x40

    def get_syn_current(self):
        """Return each neuron's synaptic current as floats (11 fractional bits)."""
        raw = (self.state >> 16) & 0xFFFF
        return self._to_signed(raw, 16).astype(float) / (1 << 11)

    def step(self, spikes):
        """Advance every neuron by one time step.

        Parameters
        ----------
        spikes : array_like of int or bool
            Net number of input spikes arriving at each neuron this step
            (negative values are inhibitory); one entry per neuron.

        Returns
        -------
        numpy.ndarray of bool
            True for each neuron that spiked during this step.
        """
        spikes = np.asarray(spikes, dtype=np.int32)

        # --- unpack the state word ---
        voltage = (self.state & 0x3F).astype(np.int32)
        refractory = (self.state >> 6) & 0x3
        bias = self._to_signed((self.state >> 8) & 0xFF, 8)
        current = self._to_signed((self.state >> 16) & 0xFFFF, 16)

        # --- synapse: exponential decay, computed with 5 extra bits ---
        current = current << 5
        decay = current >> self.decay_shift
        current -= decay
        # small positive currents whose decrement rounds to zero would never
        # reach zero on their own, so flush them
        current[decay == 0] = 0
        current = current >> 5

        # each spike adds 2**weight_shift / 2**decay_shift in
        # get_syn_current() units (1.0 == 1 << 11)
        spike_shift = self.weight_shift + 11 - self.decay_shift
        if spike_shift >= 0:
            current += spikes << spike_shift
        else:
            # very long tau_syn: a negative shift count is undefined, so
            # shift the other way instead
            current += spikes >> -spike_shift

        # saturate to the signed 16-bit storage range
        np.clip(current, -0x8000, 0x7FFF, out=current)

        # --- soma: leaky integration toward the total input current ---
        total_current = current + (bias << 10)
        rc_shift = 4  # tau_rc = dt * 2**4 = 0.016
        dv = ((total_current >> 5) - voltage) >> rc_shift

        # no voltage change during the refractory period
        dv[refractory > 0] = 0
        refractory[refractory > 0] -= 1

        voltage = voltage + dv
        voltage[voltage < 0] = 0

        # --- spike detection ---
        spiked = voltage >= 0x40
        refractory[spiked] = self.refractory_steps
        voltage[spiked] -= 0x40
        # Input strong enough to spike twice in one step: saturate just below
        # threshold so the value fits the 6-bit field and cannot overwrite the
        # refractory/bias bits when packed back.
        voltage[voltage >= 0x40] = 0x3F

        # --- pack everything back, preserving the bias bits (15-8) ---
        self.state &= 0x0000FF00
        self.state |= (current & 0xFFFF).astype(np.uint32) << 16
        self.state |= refractory.astype(np.uint32) << 6
        self.state |= voltage.astype(np.uint32)

        return spiked

A quick helper that generates Poisson input spikes


In [87]:
def poisson_spikes(rng, rate, dt=0.001):
    """Generate one time step of a signed Poisson spike train.

    Parameters
    ----------
    rng : numpy.random.RandomState
        Random source (anything with a ``rand(*shape)`` method).
    rate : array_like of float
        Firing rate per neuron in Hz. The magnitude sets the rate; the sign
        sets the sign of the emitted spikes (negative = inhibitory). Neurons
        with a rate of exactly 0 never spike.
    dt : float
        Length of the time step in seconds.

    Returns
    -------
    numpy.ndarray of int32
        Signed number of spikes per neuron in this step.
    """
    rate = np.asarray(rate, dtype=float)
    sign = np.sign(rate).astype(np.int32)  # 0 for zero rates: they never spike
    # floor |rate| so the division below never divides by zero
    magnitude = np.maximum(np.abs(rate), 0.00001)

    spikes = np.zeros(rate.shape, dtype=np.int32)
    # Accumulate exponential inter-spike intervals until every neuron has
    # passed dt; each interval that ends inside the step is one spike.
    time = -np.log(rng.rand(*rate.shape)) / magnitude
    inside = time < dt
    while inside.any():
        spikes[inside] += sign[inside]
        time += -np.log(rng.rand(*rate.shape)) / magnitude
        inside = time < dt
    return spikes

Run the model with a different input rate for each neuron to obtain a response curve


In [71]:
def compute_response(pool, max_input_rate, T, seed=None):
    """Measure each neuron's output firing rate for a ramp of input rates.

    Neuron ``i`` of ``pool`` receives Poisson input at ``rates[i]``, where
    ``rates`` spans 0..``max_input_rate`` Hz linearly across the pool.

    Parameters
    ----------
    pool : CompactPool
        Pool to drive; its state is advanced in place (it is not reset first).
    max_input_rate : float
        Highest input rate in Hz (given to the last neuron).
    T : float
        Simulated duration in seconds.
    seed : int, optional
        Seed for the input spike generator.

    Returns
    -------
    rates : numpy.ndarray
        Input rate of each neuron (Hz).
    response_curve : numpy.ndarray
        Output rate of each neuron (Hz), i.e. spike count / T.
    """
    # use the pool's own size and time step rather than notebook globals
    n = pool.n_neurons
    dt = pool.dt
    rates = np.linspace(0, max_input_rate, n)
    rng = np.random.RandomState(seed=seed)

    spike_count = np.zeros(n)
    # round, not truncate: e.g. 0.3 / 0.1 == 2.9999999999999996
    for _ in range(int(round(T / dt))):
        input_spikes = poisson_spikes(rng, rates, dt=dt)
        spike_count += pool.step(input_spikes)

    response_curve = spike_count / T
    return rates, response_curve

In [96]:
dt = 0.001
n_neurons = 100
max_input_rate = 1000        # Hz
tau_syn = 0.008              # must be dt * (a power of 2)
weight_syn = 8.0             # must be a power of 2
refractory_steps = 2         # 0 to 3
T = 5.0                      # time to simulate for (s)
seed = 0                     # fixed seed so the response curve is reproducible

pool = CompactPool(n_neurons=n_neurons, dt=dt, weight_syn=weight_syn,
                   refractory_steps=refractory_steps, tau_syn=tau_syn)

rates, response_curve = compute_response(pool, max_input_rate, T, seed=seed)

In [97]:
# Output firing rate vs. input rate. plot/xlabel/ylabel/show are used
# unqualified, so this assumes pylab-style names are in the namespace
# (e.g. via %pylab). NOTE(review): they are not imported in this notebook.
plot(rates, response_curve)
xlabel('input spike rate (Hz)')
ylabel('output spike rate (Hz)')
show()


Now let's see the effect of adjusting the bias current.


In [98]:
dt = 0.001
n_neurons = 100
max_input_rate = 1000        # Hz
tau_syn = 0.008              # must be dt * (a power of 2)
weight_syn = 8.0             # must be a power of 2
refractory_steps = 2         # 0 to 3
T = 5.0                      # time to simulate for (s)
seed = 0                     # same input spike trains for every bias value

bias = np.linspace(-10, 10, 21)

for b in bias:
    # Fresh pool per bias value so voltage/synapse state left over from the
    # previous run does not leak into this curve.
    pool = CompactPool(n_neurons=n_neurons, dt=dt, weight_syn=weight_syn,
                       refractory_steps=refractory_steps, tau_syn=tau_syn)
    pool.set_bias(np.ones(n_neurons) * b)
    rates, response_curve = compute_response(pool, max_input_rate, T, seed=seed)
    plot(rates, response_curve)

xlabel('input spike rate (Hz)')
ylabel('output spike rate (Hz)')
show()



In [ ]: